{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Planar flow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import required packages\n",
    "import torch\n",
    "import numpy as np\n",
    "import normflows as nf\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of planar layers stacked in the flow\n",
    "K = 16\n",
    "# Uncomment the next line to fix the random seed\n",
    "#torch.manual_seed(0)\n",
    "\n",
    "# Use the GPU only if it is both requested and available\n",
    "enable_cuda = True\n",
    "use_gpu = enable_cuda and torch.cuda.is_available()\n",
    "device = torch.device('cuda' if use_gpu else 'cpu')\n",
    "\n",
    "# Build the K planar layers first, then the target, then the base distribution\n",
    "flows = [nf.flows.Planar((2,)) for _ in range(K)]\n",
    "target = nf.distributions.TwoModes(2, 0.1)\n",
    "q0 = nf.distributions.DiagGaussian(2)\n",
    "\n",
    "# Assemble the flow model and move it to the selected device\n",
    "nfm = nf.NormalizingFlow(q0=q0, flows=flows, p=target)\n",
    "nfm.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Plot target distribution\n",
    "# Evaluate the target density on a grid_size x grid_size grid over [-3, 3] x [-3, 3]\n",
    "grid_size = 200\n",
    "xx, yy = torch.meshgrid(torch.linspace(-3, 3, grid_size), torch.linspace(-3, 3, grid_size))\n",
    "# Stack the grid coordinates into one row per grid point\n",
    "z = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)\n",
    "log_prob = target.log_prob(z.to(device)).to('cpu').view(*xx.shape)\n",
    "prob = torch.exp(log_prob)\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.pcolormesh(xx, yy, prob)\n",
    "plt.show()\n",
    "\n",
    "# Plot initial flow distribution\n",
    "# The samples are only histogrammed, so draw them without recording an\n",
    "# autograd graph for the 2 ** 20 points\n",
    "with torch.no_grad():\n",
    "    z, _ = nfm.sample(num_samples=2 ** 20)\n",
    "z_np = z.detach().cpu().numpy()\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.hist2d(z_np[:, 0].flatten(), z_np[:, 1].flatten(), (grid_size, grid_size), range=[[-3, 3], [-3, 3]])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Train model\n",
    "max_iter = 20000\n",
    "# NOTE(review): 2 * 20 is 40 samples per step, while the plots draw 2 ** 20;\n",
    "# confirm this is intended and not a typo for a power of two\n",
    "num_samples = 2 * 20\n",
    "anneal_iter = 10000\n",
    "annealing = True\n",
    "show_iter = 2000\n",
    "\n",
    "\n",
    "# Collect the losses in a list and convert once after the loop;\n",
    "# np.append would copy the whole history on every iteration\n",
    "loss_values = []\n",
    "\n",
    "# CUDA devices whose RNG state is saved and restored around the snapshot\n",
    "# sampling (fork_rng always covers the CPU generator)\n",
    "rng_devices = [torch.cuda.current_device()] if device.type == 'cuda' else []\n",
    "\n",
    "optimizer = torch.optim.Adam(nfm.parameters(), lr=1e-3, weight_decay=1e-4)\n",
    "for it in tqdm(range(max_iter)):\n",
    "    optimizer.zero_grad()\n",
    "    if annealing:\n",
    "        # beta grows linearly from 0.01 with the iteration and is capped at 1\n",
    "        beta = min(1., 0.01 + it / anneal_iter)\n",
    "        loss = nfm.reverse_kld(num_samples, beta=beta)\n",
    "    else:\n",
    "        loss = nfm.reverse_kld(num_samples)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "    loss_values.append(loss.item())\n",
    "    \n",
    "    # Plot learned distribution\n",
    "    if (it + 1) % show_iter == 0:\n",
    "        # Seed inside a forked RNG scope: every snapshot is drawn with the\n",
    "        # same seed on CPU and GPU alike, and the generator state used by\n",
    "        # the following training iterations is restored afterwards\n",
    "        with torch.random.fork_rng(devices=rng_devices), torch.no_grad():\n",
    "            torch.manual_seed(0)\n",
    "            z, _ = nfm.sample(num_samples=2 ** 20)\n",
    "        z_np = z.detach().cpu().numpy()\n",
    "        \n",
    "        plt.figure(figsize=(10, 10))\n",
    "        plt.hist2d(z_np[:, 0].flatten(), z_np[:, 1].flatten(), (grid_size, grid_size), range=[[-3, 3], [-3, 3]])\n",
    "        plt.show()\n",
    "\n",
    "loss_hist = np.array(loss_values)\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.plot(loss_hist, label='loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Plot learned distribution\n",
    "# The samples are only histogrammed, so draw them without recording an\n",
    "# autograd graph for the 2 ** 20 points\n",
    "with torch.no_grad():\n",
    "    z, _ = nfm.sample(num_samples=2 ** 20)\n",
    "z_np = z.detach().cpu().numpy()\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.hist2d(z_np[:, 0].flatten(), z_np[:, 1].flatten(), (grid_size, grid_size), range=[[-3, 3], [-3, 3]])\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
